I'd like to do a bit of tidying up with this one, but it can be merged if tests pass and people feel like it; I can just make a separate PR for the tidying.
Codecov Report: ✅ All modified and coverable lines are covered by tests.
8a1283e to 772b61b
Overall looks good to me!
I have two remaining questions:
- This also bumps the Mooncake version; should this come before or after #164?
- I might be getting this wrong, but I am kind of confused now about why we are zero!ing the dargs. Aren't we supposed to accumulate the contributions of the gradient into them, not set them to zero? I.e., should this not just be darg1 += zero(darg1), or effectively we don't have to do that?
[EDIT] ignore this last part, I'm stupid.
facdda6 to 4528a73
    function copy_tangent(var::Mooncake.CoDual, Δargs)
        dargs = make_mooncake_fdata(deepcopy(Δargs))
        copyto!(Mooncake.tangent(var), dargs)
        return
    end

    function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple)
        dargs = make_mooncake_fdata.(deepcopy(Δargs))
        for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs)
            if var_tangent isa Mooncake.FData
                for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg))
                    copyto!(var_f, darg_f)
                end
            else
                copyto!(var_tangent, darg)
            end
        end
        return
    end
Should these functions have been called copy_tangent! for consistency with Julia naming guidelines?
      end

    - function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
    + function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs)
What are the roles of Δargs and ȳ? Is Δargs just the allocated space for the shadow variables of args, whereas ȳ contains the actual cotangents? But in most cases they are the same, and inplace_out contains Δargs as tangents, so copy_tangent(inplace_out, ȳ) ends up copying ȳ into Δargs, i.e. into itself?
For clarity, this is just a question to make sure I understand. I don't take issue with copying into itself in these test methods.
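To make the self-copy scenario concrete, here is a minimal sketch in plain Julia. It assumes ordinary arrays as stand-ins for the tangents (no Mooncake types, and `make_mooncake_fdata` is left out), and the `snapshot` name is hypothetical. It only mirrors the `deepcopy` followed by `copyto!` pattern visible in `copy_tangent` above:

```julia
# Plain arrays standing in for the tangent storage (an assumption for illustration).
Δargs = [1.0, 2.0, 3.0]
ȳ = Δargs                  # mimics the default `ȳ = Δargs`: same object, not a copy

# copy_tangent deep-copies its second argument before copying it into the tangent.
snapshot = deepcopy(ȳ)     # independent copy of the values
copyto!(Δargs, snapshot)   # writes the same values back into Δargs

same_object = ȳ === Δargs              # the two names alias one array
unchanged = Δargs == [1.0, 2.0, 3.0]   # contents are unaffected by the round trip
```

Under these assumptions the round trip costs one extra allocation but leaves the values untouched, which is consistent with it being harmless in test helpers.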
    dA_copy_ = Mooncake.arrayify(A, dA_copy)[2]
    @test dA_inplace_ ≈ dA_copy_
    @test copy_args == inplace_args
    if dargs_copy isa Tuple
What do the dargs_... variables contain at the end of the _get_..._derivative functions? They are all set to zero in the pullback, right? Would it make sense to test that dargs_copy and dargs_inplace are both zero, rather than only testing that they are equal (which is trivially true if they are both zero)?
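A small sketch of the difference between the two kinds of check, using only the `Test` standard library. The arrays here are hypothetical stand-ins, not the actual values produced by the test helpers, and the sketch assumes the expected final state really is all zeros:

```julia
using Test

# Hypothetical stand-ins for the two results being compared.
dargs_copy = zeros(3)
dargs_inplace = zeros(3)

@testset "equality vs. zero check" begin
    @test dargs_copy == dargs_inplace   # passes, but does not pin down the values
    @test iszero(dargs_copy)            # states the expected value explicitly
    @test iszero(dargs_inplace)
end

# Equality alone would also pass for two results that are wrong in the same way.
wrong_copy = ones(3)
wrong_inplace = ones(3)
equal_but_nonzero = (wrong_copy == wrong_inplace) && !iszero(wrong_copy)
```

The explicit `iszero` checks would catch a case where both code paths leave the same nonzero residue, which a pure equality check cannot.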
Added more tests to test_pullbacks_match to make sure the state of the arguments is restored, and that the final argument derivatives match between the in-place and non-in-place methods.

Unfortunately, the Mooncake FD tester doesn't work well for our functions, because A becomes a scratch space, and the inputs are also the outputs (so they get incremented twice under the FD scheme).